import torch
from torch.nn.parallel import DistributedDataParallel
from kantts.models.hifigan.hifigan import (  # NOQA
    Generator,  # NOQA
    MultiScaleDiscriminator,  # NOQA
    MultiPeriodDiscriminator,  # NOQA
    MultiSpecDiscriminator,  # NOQA
)
import kantts
import kantts.train.scheduler
from kantts.models.sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT  # NOQA
from kantts.utils.ling_unit.ling_unit import get_fpdict
from .pqmf import PQMF


def optimizer_builder(model_params, opt_name, opt_params):
    """Instantiate the ``torch.optim`` optimizer class named ``opt_name``.

    Args:
        model_params: Parameters (or parameter groups) handed to the
            optimizer constructor as its first argument.
        opt_name: Attribute name of the optimizer class in ``torch.optim``.
        opt_params: Keyword arguments forwarded to the optimizer constructor.

    Returns:
        The constructed optimizer instance.

    Raises:
        AttributeError: If ``torch.optim`` has no attribute ``opt_name``.
    """
    return getattr(torch.optim, opt_name)(model_params, **opt_params)


def scheduler_builder(optimizer, sche_name, sche_params):
    """Instantiate the scheduler class named ``sche_name``.

    The class is looked up as an attribute of ``kantts.train.scheduler``.

    Args:
        optimizer: Optimizer passed to the scheduler constructor as its
            first argument.
        sche_name: Attribute name of the scheduler class in
            ``kantts.train.scheduler``.
        sche_params: Keyword arguments forwarded to the scheduler constructor.

    Returns:
        The constructed scheduler instance.

    Raises:
        AttributeError: If ``kantts.train.scheduler`` has no attribute
            ``sche_name``.
    """
    return getattr(kantts.train.scheduler, sche_name)(optimizer, **sche_params)


def hifigan_model_builder(config, device, rank, distributed):
    """Build the HiFi-GAN networks with their optimizers and schedulers.

    Every entry of ``config["Model"]`` is instantiated from its ``"params"``.
    The entry named ``"Generator"`` becomes the generator; any other entry is
    resolved by name among this module's globals and stored as a
    discriminator under that same name.

    Args:
        config: Mapping with a ``"Model"`` section whose entries each carry
            ``"params"``, ``"optimizer"`` and ``"scheduler"``. An optional
            top-level ``"pqmf"`` entry supplies extra PQMF keyword arguments.
        device: Device every module is moved to.
        rank: Device index handed to ``DistributedDataParallel``.
        distributed: When true, wrap the generator and all discriminators
            in ``DistributedDataParallel``.

    Returns:
        ``(model, optimizer, scheduler)`` dicts. Each has a ``"generator"``
        entry and a ``"discriminator"`` dict keyed by model name; ``model``
        also gets ``"pqmf"`` when the generator's ``out_channels`` is
        greater than one.

    Raises:
        KeyError: If a required config entry (including
            ``config["Model"]["Generator"]``) is missing, or a discriminator
            name is not defined in this module.
    """
    model = {"discriminator": {}}
    optimizer = {"discriminator": {}}
    scheduler = {"discriminator": {}}

    for name, spec in config["Model"].items():
        is_generator = name == "Generator"

        hyper_params = spec["params"]
        if is_generator:
            net = Generator(**hyper_params).to(device)
        else:
            # Discriminator classes are looked up by their config name.
            net = globals()[name](**hyper_params).to(device)

        # Optimizer type defaults to "Adam", scheduler type to "StepLR".
        optim_spec = spec["optimizer"]
        optim = optimizer_builder(
            net.parameters(),
            optim_spec.get("type", "Adam"),
            optim_spec.get("params", {}),
        )
        sched_spec = spec["scheduler"]
        sched = scheduler_builder(
            optim,
            sched_spec.get("type", "StepLR"),
            sched_spec.get("params", {}),
        )

        if is_generator:
            model["generator"] = net
            optimizer["generator"] = optim
            scheduler["generator"] = sched
        else:
            model["discriminator"][name] = net
            optimizer["discriminator"][name] = optim
            scheduler["discriminator"][name] = sched

    # A generator with several output channels gets a PQMF module whose
    # subband count equals that channel count.
    out_channels = config["Model"]["Generator"]["params"]["out_channels"]
    if out_channels > 1:
        pqmf_kwargs = config.get("pqmf", {})
        model["pqmf"] = PQMF(subbands=out_channels, **pqmf_kwargs).to(device)

    if not distributed:
        return model, optimizer, scheduler

    # FIXME: pywavelets buffer leads to gradient error in DDP training
    # Solution: https://github.com/pytorch/pytorch/issues/22095
    model["generator"] = DistributedDataParallel(
        model["generator"],
        device_ids=[rank],
        output_device=rank,
        broadcast_buffers=False,
    )
    model["discriminator"] = {
        name: DistributedDataParallel(
            disc,
            device_ids=[rank],
            output_device=rank,
            broadcast_buffers=False,
        )
        for name, disc in model["discriminator"].items()
    }

    return model, optimizer, scheduler


#  TODO: some parsing
def sambert_model_builder(config, device, rank, distributed):
    """Build the KanTtsSAMBERT model with its optimizer and scheduler.

    Args:
        config: Mapping whose ``config["Model"]["KanTtsSAMBERT"]`` section
            carries ``"params"``, ``"optimizer"`` and ``"scheduler"``.
        device: Device the model (and any FP tensors) is moved to.
        rank: Device index handed to ``DistributedDataParallel``.
        distributed: When true, wrap the model in
            ``DistributedDataParallel``.

    Returns:
        ``(model, optimizer, scheduler)`` dicts, each with the single key
        ``"KanTtsSAMBERT"``.
    """
    name = "KanTtsSAMBERT"
    section = config["Model"][name]

    net = KanTtsSAMBERT(section["params"]).to(device)

    # With the "FP" param enabled, each array from get_fpdict(config) becomes
    # a long tensor with a leading dimension of size 1, attached to the model
    # as ``fp_dict``.
    if section["params"].get("FP", False):
        fp_tensors = {}
        for key, array in get_fpdict(config).items():
            fp_tensors[key] = (
                torch.from_numpy(array).long().unsqueeze(0).to(device)
            )
        net.fp_dict = fp_tensors

    # Optimizer type defaults to "Adam", scheduler type to "StepLR".
    optim_spec = section["optimizer"]
    optim = optimizer_builder(
        net.parameters(),
        optim_spec.get("type", "Adam"),
        optim_spec.get("params", {}),
    )
    sched_spec = section["scheduler"]
    sched = scheduler_builder(
        optim,
        sched_spec.get("type", "StepLR"),
        sched_spec.get("params", {}),
    )

    if distributed:
        net = DistributedDataParallel(
            net, device_ids=[rank], output_device=rank
        )

    return {name: net}, {name: optim}, {name: sched}


def sybert_model_builder(config, device, rank, distributed):
    """Build the KanTtsTextsyBERT model with its optimizer and scheduler.

    Args:
        config: Mapping whose ``config["Model"]["KanTtsTextsyBERT"]`` section
            carries ``"params"``, ``"optimizer"`` and ``"scheduler"``.
        device: Device the model is moved to.
        rank: Device index handed to ``DistributedDataParallel``.
        distributed: When true, wrap the model in
            ``DistributedDataParallel``.

    Returns:
        ``(model, optimizer, scheduler)`` dicts, each with the single key
        ``"KanTtsTextsyBERT"``.
    """
    name = "KanTtsTextsyBERT"
    section = config["Model"][name]

    net = KanTtsTextsyBERT(section["params"]).to(device)

    # Optimizer type defaults to "Adam", scheduler type to "StepLR".
    optim_spec = section["optimizer"]
    optim = optimizer_builder(
        net.parameters(),
        optim_spec.get("type", "Adam"),
        optim_spec.get("params", {}),
    )
    sched_spec = section["scheduler"]
    sched = scheduler_builder(
        optim,
        sched_spec.get("type", "StepLR"),
        sched_spec.get("params", {}),
    )

    if distributed:
        net = DistributedDataParallel(
            net, device_ids=[rank], output_device=rank
        )

    return {name: net}, {name: optim}, {name: sched}


#  TODO: implement a builder for specific model
# Maps the ``model_type`` value read from the config to the function that
# builds that model family; ``model_builder`` dispatches through this table.
model_dict = {
    "hifigan": hifigan_model_builder,
    "sambert": sambert_model_builder,
    "sybert": sybert_model_builder,
}


def model_builder(config, device="cpu", rank=0, distributed=False):
    """Build models, optimizers and schedulers for ``config["model_type"]``.

    Args:
        config: Mapping with a ``"model_type"`` key naming an entry of
            ``model_dict``; the whole mapping is passed on to that builder.
        device: Device forwarded to the selected builder.
        rank: Rank forwarded to the selected builder.
        distributed: Distributed-training flag forwarded to the builder.

    Returns:
        The ``(model, optimizer, scheduler)`` triple produced by the
        selected builder.

    Raises:
        KeyError: If ``"model_type"`` is missing from ``config`` or is not a
            key of ``model_dict``.
    """
    models, optimizers, schedulers = model_dict[config["model_type"]](
        config, device, rank, distributed
    )
    return models, optimizers, schedulers
